from ...smp import *
from .multiple_choice import extract_answer_from_item
import numpy as np
import re

FAIL_MSG = "Failed to obtain answer via API."

DURATIONS = [
    "short",
    "medium",
    "long",
]

DOMAINS = [
    "Knowledge",
    "Film & Television",
    "Sports Competition",
    "Artistic Performance",
    "Life Record",
    "Multilingual",
]

SUB_CATEGORIES = [
    "Humanity & History",
    "Literature & Art",
    "Biology & Medicine",
    "Finance & Commerce",
    "Astronomy",
    "Geography",
    "Law",
    "Life Tip",
    "Technology",
    "Animation",
    "Movie & TV Show",
    "Documentary",
    "News Report",
    "Esports",
    "Basketball",
    "Football",
    "Athletics",
    "Other Sports",
    "Stage Play",
    "Magic Show",
    "Variety Show",
    "Acrobatics",
    "Handicraft",
    "Food",
    "Fashion",
    "Daily Life",
    "Travel",
    "Pet & Animal",
    "Exercise",
    "Multilingual",
]

TASK_CATEGORIES = [
    "Temporal Perception",
    "Spatial Perception",
    "Attribute Perception",
    "Action Recognition",
    "Object Recognition",
    "OCR Problems",
    "Counting Problem",
    "Temporal Reasoning",
    "Spatial Reasoning",
    "Action Reasoning",
    "Object Reasoning",
    "Information Synopsis",
]


def get_dimension_rating(data_path):
    data = load(data_path)

    duration_rating = {k: {} for k in DURATIONS}
    for duration in DURATIONS + ["overall"]:
        duration_rating[duration] = {
            "overall": "",
            "domain": {k: [] for k in DOMAINS},
            "sub_category": {k: [] for k in SUB_CATEGORIES},
            "task_type": {k: [] for k in TASK_CATEGORIES},
        }

    for i in range(len(data)):

        domain = data.iloc[i]["domain"]
        sub_ctg = data.iloc[i]["sub_category"]
        task_ctg = data.iloc[i]["task_type"]

        duration = data.iloc[i]["duration"]
        duration_rating[duration]["domain"][domain].append(data.iloc[i]["score"])
        duration_rating[duration]["sub_category"][sub_ctg].append(data.iloc[i]["score"])
        duration_rating[duration]["task_type"][task_ctg].append(data.iloc[i]["score"])

        duration_rating["overall"]["domain"][domain].append(data.iloc[i]["score"])
        duration_rating["overall"]["sub_category"][sub_ctg].append(
            data.iloc[i]["score"]
        )
        duration_rating["overall"]["task_type"][task_ctg].append(data.iloc[i]["score"])

    for duration in DURATIONS + ["overall"]:

        overall_res_dur = f'{np.mean([x for x in sum(duration_rating[duration]["domain"].values(), []) if x >= 0]):.3f}'
        duration_rating[duration]["overall"] = overall_res_dur

        for domain in DOMAINS:
            domain_res_dur = f'{np.mean([x for x in duration_rating[duration]["domain"][domain] if x >= 0]):.3f}'
            duration_rating[duration]["domain"][domain] = domain_res_dur

        for sub_ctg in SUB_CATEGORIES:
            sub_res_dur = f'{np.mean([x for x in duration_rating[duration]["sub_category"][sub_ctg] if x >= 0]):.3f}'
            duration_rating[duration]["sub_category"][sub_ctg] = sub_res_dur

        for task_ctg in TASK_CATEGORIES:
            task_res_dur = f'{np.mean([x for x in duration_rating[duration]["task_type"][task_ctg] if x >= 0]):.3f}'
            duration_rating[duration]["task_type"][task_ctg] = task_res_dur

    return duration_rating


def extract_option(model, input_item, dataset_name):
    options = input_item["question"].split("\n")[1:]
    for id, option in enumerate(options):
        option_id = chr(ord("A") + id) + "."
        if option.find(option_id) >= 0:
            input_item[chr(ord("A") + id)] = option[
                option.find(option_id) + len(option_id) :
            ].strip(". \n")
    return extract_answer_from_item(model, input_item, dataset_name)["opt"]


def extract_characters_regex(s):
    s = s.strip()
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        "The best option is" "The correct option is",
        "Best answer:" "Best option:",
        "Answer:",
        "Option:",
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")

    if len(s.split()) > 10 and not re.search("[ABCD]", s):
        return ""
    matches = re.search(r"[ABCD]", s)
    if matches is None:
        return ""
    return matches[0]
